import copy
import os
import re
import struct
import xml.etree.ElementTree as ET

from google.protobuf.descriptor import FieldDescriptor as FD
from api_server.dragonmaze.protopy.game_message_pb2 import *
from api_server.dragonmaze.protopy.account_message_pb2 import *
from api_server.mergical.protopy.game_server_message_pb2 import *
from api_server.townest.protopy.ActivityModuleMsg_pb2 import *
from api_server.townest.protopy.AdsModuleMsg_pb2 import *
from api_server.townest.protopy.BoxModuleMsg_pb2 import *
from api_server.townest.protopy.BuffModuleMsg_pb2 import *
from api_server.townest.protopy.BuildingModuleMsg_pb2 import *
from api_server.townest.protopy.CoreModuleMsg_pb2 import *
from api_server.townest.protopy.DayRewardModuleMsg_pb2 import *
from api_server.townest.protopy.ErrorCodeEnums_pb2 import *
from api_server.townest.protopy.FriendsModuleMsg_pb2 import *
from api_server.townest.protopy.GatewayModuleMsg_pb2 import *
from api_server.townest.protopy.GMModuleMsg_pb2 import *
from api_server.townest.protopy.ItemModuleMsg_pb2 import *
from api_server.townest.protopy.LevelsModuleMsg_pb2 import *
from api_server.townest.protopy.TaskModuleMsg_pb2 import *


def pb2dict(obj):
    """
    Convert a ProtoBuf message object into a plain dict.

    Non-repeated scalar fields are copied as-is; repeated fields become
    lists (message elements converted recursively); a non-repeated
    sub-message is converted recursively and only kept when the result is
    truthy (so ``None`` and ``{}`` are dropped).  Returns ``None`` when
    ``obj.IsInitialized()`` reports the message as not initialized.
    """
    if not obj.IsInitialized():
        return None
    converted = {}
    for field in obj.DESCRIPTOR.fields:
        field_name = field.name
        field_value = getattr(obj, field_name)
        holds_message = field.type == FD.TYPE_MESSAGE
        if field.label == FD.LABEL_REPEATED:
            if holds_message:
                converted[field_name] = [pb2dict(item) for item in field_value]
            else:
                converted[field_name] = list(field_value)
        elif holds_message:
            nested = pb2dict(field_value)
            if nested:
                converted[field_name] = nested
        else:
            converted[field_name] = field_value
    return converted


def _resolve_field(message, path):
    """Return the attribute of *message* reached by the dotted *path*.

    Only plain ``getattr`` traversal is performed, so keys coming from
    request data are never evaluated as code.  Segments starting with
    ``__`` are refused so data cannot reach interpreter internals such as
    ``__class__``.
    """
    target = message
    for part in str(path).split("."):
        if part.startswith("__"):
            raise AttributeError("refusing to access %r on a protobuf message" % part)
        target = getattr(target, part)
    return target


def _fill_message(message, data):
    """Copy *data* (a dict keyed by field name) onto the protobuf *message*.

    * scalar values are assigned with ``setattr``;
    * dict values fill the sub-message found at that key, recursively
      (composite fields cannot be assigned directly);
    * list values are appended to the repeated field at that key; dict
      items become new elements created with ``add()`` and are filled
      recursively, other items are appended as-is.

    Keys of dict/list values may be dotted paths (``"sub.field"``) into
    nested messages.
    """
    for key, value in data.items():
        if isinstance(value, dict):
            _fill_message(_resolve_field(message, key), value)
        elif isinstance(value, list):
            container = _resolve_field(message, key)
            for item in value:
                if isinstance(item, dict):
                    _fill_message(container.add(), item)
                else:
                    container.append(item)
        else:
            setattr(message, key, value)


def pb2seria(message_info: dict, data: dict, repeated: bool = False):
    """Build a protobuf message from *data* and return its serialized bytes.

    The key used to look up the message class in *message_info* is:

    * ``str(data["name"])`` when present (the ``name`` entry is also copied
      onto the message like any other field);
    * otherwise ``str(data["msgId"])``, in which case the fields come from
      ``data["msgBody"]`` (a missing or empty body yields an empty message);
    * otherwise the fixed key ``"townest"``.

    *message_info* values may be message class names, resolved among this
    module's globals (the ``*_pb2`` star imports), or message classes.

    :param repeated: retained for backward compatibility; it has no effect,
        nested dicts and lists are handled identically in every case.
    :raises KeyError: if no message class can be found for the key.
    :raises AttributeError: if a dict/list key contains a ``__`` segment.
        Errors raised by protobuf for unknown or mistyped fields propagate.
    """
    if "name" in data:
        trans_name = str(data.get("name"))
    elif "msgId" in data:
        trans_name = str(data.get("msgId"))
        data = data.get("msgBody") or {}
    else:
        trans_name = "townest"

    class_ref = message_info.get(trans_name)
    # Look names up in the module namespace instead of eval() so the value
    # can never be executed as an expression.
    message_cls = globals().get(class_ref) if isinstance(class_ref, str) else class_ref
    if not callable(message_cls):
        raise KeyError(
            "no protobuf message class for %r (got %r)" % (trans_name, class_ref)
        )

    message = message_cls()
    _fill_message(message, data)
    return message.SerializeToString()


def pb2parse(project_name, file_name, result):
    """Scan the .proto file *file_name* and record message names in *result*.

    The mapping written depends on *project_name*:

    * ``"dragonmaze"``: ``result[<default value>] = <message name>`` for each
      text span running from ``message`` to the first ``}`` after a
      ``default`` option; the first header and first ``default = ...]`` in
      the span are used.
    * ``"mergical"``: ``result[name] = name`` for each message whose name
      starts with ``R``, ``Q`` or ``P``.
    * ``"townest"``: for every ``<Name>Request`` message that has a
      ``..._SUB_<Name> = <code>;`` constant (case-insensitive) in the same
      file, ``result["request"][code]`` / ``result["response"][code]`` are
      set to ``<Name>Request`` / ``<Name>Response``.  *result* must already
      contain ``"request"`` and ``"response"`` dicts.

    Other project names leave *result* unchanged.
    """
    with open(file_name, encoding='utf-8') as file_object:
        file_context = file_object.read()

    if project_name == "dragonmaze":
        # NOTE(review): the lazy ``message.*?default`` can start at an earlier
        # message that has no default option; that earlier name is then
        # recorded for the default value.  Confirm the proto layout avoids it.
        for paragraph in re.findall(r"message.*?default.*?\}", file_context, re.S):
            message = re.findall(r"message (.+?) .*", paragraph)
            name = re.findall(r".*default = (.+?)\].*", paragraph)
            if message and name:
                result[name[0]] = message[0]
    elif project_name == "mergical":
        # One span per R/Q/P message, up to its first closing brace.
        for paragraph in re.findall(r"message [RQP].*?\}", file_context, re.S):
            message = re.findall(r"message ([RQP].+?)[ {].*", paragraph)
            if message:
                result[message[0]] = message[0]
    elif project_name == "townest":
        # Match each "message <Name>Request" header directly and keep <Name>.
        for name in re.findall(r"\bmessage\s+(\w+)Request\b", file_context):
            code = re.findall(r"_SUB_%s = (\d+);" % re.escape(name), file_context, re.I)
            if code:
                result["request"][code[0]] = name + "Request"
                result["response"][code[0]] = name + "Response"


def get_proto_filename(file_dir, result):
    """Parse every ``.proto`` file in ``<file_dir>/proto`` with the townest rules.

    Each file is passed to ``pb2parse("townest", path, result)``, which
    fills ``result["request"]`` / ``result["response"]`` keyed by code.
    Files are handled in sorted name order, so when a later file writes the
    same code as an earlier one, which entry wins is reproducible.
    """
    proto_dir = os.path.join(file_dir, 'proto')
    for file_name in sorted(os.listdir(proto_dir)):
        if os.path.splitext(file_name)[1] == '.proto':
            # Read the file from the same directory that was listed.
            pb2parse("townest", os.path.join(proto_dir, file_name), result)


class ParserXml(object):
    """Encode/decode big-endian binary messages described by an XML schema.

    ``xmlParser`` loads message and complex-type definitions into
    ``api_datas``; ``xml2serial`` turns a dict into
    ``msgId ('>h') + payload length ('>h') + payload``; ``serial2dict``
    decodes fields from the front of ``content`` into a dict.
    """

    # Class attribute mutated in place by xmlParser, so definitions loaded
    # through one instance are shared by every instance.
    api_datas = {}
    # Bytes not yet consumed by serial2dict (decoding starts at offset 0, so
    # any message header must already be stripped); presumably assigned by
    # the caller before decoding -- confirm with callers.
    content = b''
    # Payload built by param2serial; xml2serial resets it for each message.
    param_bytes = b''

    # struct format for each fixed-width scalar alias used in the XML.
    _SCALAR_FORMATS = {
        'long': '>q',
        'int': '>i',
        'short': '>h',
        'byte': '>b',
    }

    def xmlParser(self, filename):
        """Load definitions from the XML document *filename* into ``api_datas``.

        Every ``<node>`` is keyed by its ``messageId`` property when it has
        one, otherwise by its ``alias`` attribute.  Its entry maps
        ``"params"`` to the attribute dicts of all ``<field>`` elements below
        it; each field whose alias is ``list`` is additionally stored under
        its own name as a one-element list describing the element type
        (``alias`` replaced by the field's ``generic``).

        Returns the shared ``api_datas`` dict.
        """
        tree = ET.ElementTree(file=filename)
        # Get the root node.
        root = tree.getroot()
        for node in root.iter(tag='node'):
            node_attrib_alias = node.attrib.get("alias")

            # A messageId property overrides the alias as the lookup key.
            for pros in node.findall('properties/property'):
                pro_attrib = pros.attrib
                if pro_attrib.get("key") == "messageId":
                    node_attrib_alias = pro_attrib.get("value")

            self.api_datas[node_attrib_alias] = {}

            # Parameter list of a message, or the members of a complex type.
            field_list = []
            for field in node.iter(tag='field'):
                field_attrib = field.attrib
                field_list.append(field_attrib)
                if field_attrib.get("alias") == 'list':
                    # Single-field definition used to (de)serialize one element.
                    element = copy.deepcopy(field_attrib)
                    element["alias"] = element.pop("generic")
                    self.api_datas[node_attrib_alias][element.get("name")] = [element]

            self.api_datas[node_attrib_alias]["params"] = field_list

        return self.api_datas

    def param2serial(self, data, msg_id, key="params"):
        """Append the binary encoding of *data* to ``self.param_bytes``.

        *data* is a dict keyed by field name and laid out according to
        ``self.api_datas[msg_id][key]``.  Strings are encoded as a '>h' byte
        length followed by UTF-8 bytes; ``long``/``int``/``short``/``byte``
        use their big-endian struct formats; lists are a '>h' element count
        followed by the elements; any other alias is treated as a complex
        type with its own definition in ``api_datas``.
        """
        for param in self.api_datas[msg_id][key]:
            param_type = param.get("alias")
            param_name = param.get("name")
            value = data.get(param_name)
            if param_type == 'string':
                encoded = value.encode("utf-8")
                self.param_bytes += struct.pack('>h', len(encoded)) + encoded
            elif param_type in self._SCALAR_FORMATS:
                self.param_bytes += struct.pack(self._SCALAR_FORMATS[param_type], value)
            elif param_type == 'list':
                list_type = param.get("generic")
                self.param_bytes += struct.pack('>h', len(value))
                for each_value in value:
                    if list_type in self.api_datas:
                        # Elements are complex types with their own definition.
                        self.param2serial(each_value, list_type)
                    else:
                        # Elements are scalars: encode each one through the
                        # single-field definition stored under the list's
                        # name, which reads the value keyed by that name.
                        self.param2serial({param_name: each_value}, msg_id, param_name)
            else:
                # Complex (nested) type.
                self.param2serial(value, param_type)

    def xml2serial(self, raw_data):
        """Serialize *raw_data* into ``msgId + payload length + payload`` bytes.

        ``raw_data["messageId"]`` selects the definition (looked up as a
        string key) and is packed as '>h'; the remaining keys supply field
        values.  The payload buffer is reset first, so each call returns
        only the current message.
        """
        msg_id = raw_data["messageId"]
        msg_id_bytes = struct.pack('>h', msg_id)
        # param2serial only appends, so start each message from empty.
        self.param_bytes = b''
        self.param2serial(raw_data, str(msg_id))
        param_bytes_len_bytes = struct.pack('>h', len(self.param_bytes))
        return msg_id_bytes + param_bytes_len_bytes + self.param_bytes

    def serial2dict(self, resp_id, resJson, key="params"):
        """Decode fields from the front of ``self.content`` into *resJson*.

        The layout is ``self.api_datas[str(resp_id)][key]``; consumed bytes
        are removed from ``self.content``.  Lists of complex types decode to
        a list of dicts, lists of scalars to a list of values, and any other
        unknown alias to a nested dict.
        """
        for param in self.api_datas[str(resp_id)][key]:
            param_type = param.get("alias")
            param_name = param.get("name")
            if param_type == 'string':
                name_len = struct.unpack('>h', self.content[0:2])[0]
                value = struct.unpack(
                    str(name_len) + 's', self.content[2:name_len + 2]
                )[0].decode("utf-8")
                self.content = self.content[name_len + 2:]
            elif param_type in self._SCALAR_FORMATS:
                fmt = self._SCALAR_FORMATS[param_type]
                width = struct.calcsize(fmt)
                value = struct.unpack(fmt, self.content[0:width])[0]
                self.content = self.content[width:]
            elif param_type == 'list':
                element_type = param.get("generic")
                value_len = struct.unpack('>h', self.content[0:2])[0]
                self.content = self.content[2:]
                value = []
                is_complex = element_type in self.api_datas
                for _ in range(value_len):
                    # Fresh dict per element so elements don't overwrite each other.
                    element = {}
                    if is_complex:
                        self.serial2dict(element_type, element)
                        value.append(element)
                    else:
                        # Scalar element decoded via the list's single-field definition.
                        self.serial2dict(resp_id, element, param_name)
                        value.append(element[param_name])
            else:
                # Complex (nested) type.
                value = {}
                self.serial2dict(param_type, value)

            resJson[param_name] = value
